Skip to content

[core] fix group offloading when using torchao#13276

Open
sayakpaul wants to merge 8 commits intomainfrom
fix-torchao-groupoffloading
Open

[core] fix group offloading when using torchao#13276
sayakpaul wants to merge 8 commits intomainfrom
fix-torchao-groupoffloading

Conversation

@sayakpaul
Copy link
Copy Markdown
Member

@sayakpaul sayakpaul commented Mar 17, 2026

What does this PR do?

Fix offloading when using TorchAO. This assumes that the underlying quantization tensor class implements pinning properly. But that's not something we can do in TorchAO, anyway.

The benefit of this is that many new releases benefit from quantization schemes robustly implemented and tested in TorchAO. But quantization alone rarely helps, we need offloading too. Many large models need group offloading (overlapping compute with data transfer).

Problem

Group offloading moves parameters between CPU and GPU by reassigning param.data:

param.data = source_tensor.to(device) 

This works for regular tensors but breaks for TorchAO quantized tensors.

TorchAO tensors are special instances that store their actual data in internal attributes (e.g., .qdata, .scale), not in the standard tensor storage. The .data assignment replaces the
outer wrapper storage but leaves these internal attributes on the original device, causing a device mismatch at compute time.

A further subtlety: accessing .data on a wrapper subclass parameter returns a new wrapper object each time, so mutating attributes on param.data doesn't persist either.

This PR

~.data approach~ (not pursued)

For TorchAO tensors, instead of reassigning data, we update the internal tensor attributes directly on the parameter object itself:

# Before (broken for TorchAO tensors)                                                                                                    
param.data = source_tensor.to(device)                                                                                                    
                                                                                                                                         
# After                                                                                                                                  
moved = source_tensor.to(device)                                                                                                         
if _is_torchao_tensor(param):                                                                                                            
    for attr in tensor_data_names:  # e.g. ["qdata", "scale"]                                                                            
        setattr(param, attr, getattr(moved, attr))                                                                                       
else:                                                                                                                                    
    param.data = moved     

For TorchAO tensors, param.data = source_tensor.to(device) doesn't work because _make_wrapper_subclass tensors store their actual data in internal attributes (.qdata, .scale, etc.), and the .data setter only replaces the outer wrapper storage.

We use two strategies depending on the code path:

Onload — torch.utils.swap_tensors, which swaps the full tensor contents in-place:

moved = source_tensor.to(device)
if _is_torchao_tensor(param):
   torch.utils.swap_tensors(param, moved)
else:
   param.data = moved

Offload (with stream) — setattr to copy internal tensor references without mutating the cached CPU copy:

if _is_torchao_tensor(param):
    for attr in tensor_data_names:  # e.g. ["qdata", "scale"]
        setattr(param, attr, getattr(cpu_cached_copy, attr))
else:
    param.data = cpu_cached_copy

swap_tensors can't be used for the stream offload path because it's bidirectional — it would put CUDA data into the cached CPU copy, corrupting it for the next onload cycle.

Related issue: pytorch/ao#4088.

Happens with nightlies as well.

Code to test: https://gist.github.com/sayakpaul/929678132809874c5dbf9c5215460d33#file-check_torchao_offload_compile-py (run with --quantize, --group-offload; and potentially with --full-compile).

Nice results (with quantization + group offloading + full compile):

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:32<00:00,  8.06s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.18s/it]

Needs the nightlies (of both PyTorch and TorchAO) for testing.

Important

While this PR executes the TorchAO-specific changes, I think we could refactor group offloading-related utilities to rely on swap_tensors, instead of .data as it is considered to be a private API. Separate PR, of course.

Will update tests later.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

pinned_dict = None

def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean the to op is not implemented properly for torchao tensors?

if you have a minimal repro, we might be able to fix I think

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
quantize_(linear, Int8WeightOnlyConfig(version=2))
p = linear.weight

# Move a copy to CUDA and assign via .data
cpu_copy = p.data.cpu()
cuda_copy = cpu_copy.to("cuda")
p.data = cuda_copy

print(f"p.qdata.device = {p.qdata.device}")  # cpu
print(f"cuda_copy.qdata.device = {cuda_copy.qdata.device}")  # cuda:0

# Forward fails: input on cuda, weight internals still on cpu
linear.bias.data = linear.bias.data.to("cuda")
x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
linear(x)  # RuntimeError: mat2 is on cpu

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel the proper way to do this is:

# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
quantize_(linear, Int8WeightOnlyConfig(version=2))
p = linear.weight

# Move a copy to CUDA and assign via .data
# cpu_copy = p.data.cpu()
cpu_copy = p.cpu()
cuda_copy = cpu_copy.to("cuda")
# p.data = cuda_copy
torch.utils.swap_tensors(linear.weight, cuda_copy)

print(f"p.qdata.device = {p.qdata.device}")  # cpu
print(f"cuda_copy.qdata.device = {cuda_copy.qdata.device}")  # cuda:0

# Forward fails: input on cuda, weight internals still on cpu
linear.bias.data = linear.bias.data.to("cuda")
x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
linear(x)  # RuntimeError: mat2 is on cpu

parameter.data is not a recommended API. and linear.weight is also no longer a nn.Parameter after quantization, it's a different tensor subclass (nn.Parameter is also a tensor subclass: https://github.com/pytorch/pytorch/blob/e9ebbd3bee0761eb9d93b53f4a80d3afa2cc46f8/torch/nn/parameter.py#L30).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this does not seem like a fix? Your snippet mentions linear(x) also fails because of the runtime error. Elaborate?

Also, how do I best implement it in the context of the error and the diffusers code? I provided the minimal snippet for your convenience, but it doesn't serve the use case. We need to be able to fix it in the context of the use case.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this does not seem like a fix? Your snippet mentions linear(x) also fails because of the runtime error. Elaborate?

this is the proper way to move device for a tensor subclass instance I think. please ignore comments, that was copied from your original example. this runs on my side.

Also, how do I best implement it in the context of the error and the diffusers code? I provided the minimal snippet for your convenience, but it doesn't serve the use case. We need to be able to fix it in the context of the use case.

basically we should not be using parameter.data = parameter.data.to("cuda") for quantized weights, but use swap_tensors instead.

we have to go through all linear modules in the model, and use swap_tensor to change device:

for n, m in model.named_modules():
    if isinstance(m, nn.Linear):
        torch.utils.swap_tensors(m.weight, m.weight.to("cuda"))

for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=False)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember hearing from Brian and Alban before that param.data is a private API and we should not rely on it, I think it also does not work with tensor subclasses

@sayakpaul sayakpaul requested a review from jerryzh168 March 23, 2026 05:30
if self.record_stream:
tensor.data.record_stream(default_stream)
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could potentially implement the record_stream op as a torch_function op in the torchao tensor subclasses as well I think so that you can do:

tensor.record_stream(default_stream) directly.

also wondering if this would work if you just do this for nn.Parameter as well (parameter.record_stream(default_stream) instead of (parameter.data.record_stream(default_stream))?

Copy link
Copy Markdown
Member Author

@sayakpaul sayakpaul Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tensor.record_stream(default_stream) directly.

Suggestion looks great. But I guess that will take some work on your end to ship. Maybe we can add a comment about it here and revisit when you land it?

also wondering if this would work if you just do this for nn.Parameter as well (parameter.record_stream(default_stream) instead of (parameter.data.record_stream(default_stream))?

Wouldn't mind refactoring it from tensor.data.record_stream(default_stream) pattern but we couldn't find out other solutions when we started working on it last year. Separate PR perhaps?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good to check in another PR

for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly for this one I'm wondering if it would make sense to implement some copy op in torchao tensor subclasses, also cpu_param_dict can store the torchao tensor subclass instances directly as well, instead of looking into implementation details

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed that would be great!

logger = get_logger(__name__) # pylint: disable=invalid-name


def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while most of the torchao tensor subclasses are developed on top of TorchAOBaseTensor. it's not a requirement to use it. practically this should work for most of the use case but it's not 100% guaranteed

I feel ideally / long term, we can refactor all uses of parameter.data to just operate on parameter itself (if it works)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would love that but sadly that's not the case currently as we cannot always control implementation details from external dependencies.

@sayakpaul
Copy link
Copy Markdown
Member Author

@jerryzh168 thanks for sharing further thoughts here! I appreciate them.

While I agree with the comments on long-term vision, I don't think we can assume / control underlying implementation details coming from external dependencies (such as TorchAO and other quantization backends).

Hence, being explicit about the control paths and keeping them separate feels like a fair compromise (compromise because of the increased cyclomatic complexity coming from the conditionals).

But as we go ahead and implement several aspects in TorchAO (record stream, copy, etc.). I think we can work together to reflect them within Diffusers.

WDYT?

@jerryzh168
Copy link
Copy Markdown

@sayakpaul yeah current fix looks good, I'll try to add these things in torchao and let you know, then you can gradually refactor / improve these. no pressure on refactoring all parameter.data usages, but we can at least have torchao path do things properly I think.

@sayakpaul sayakpaul changed the title [WIP] [core] fix group offloading when using torchao [core] fix group offloading when using torchao Mar 26, 2026
@sayakpaul
Copy link
Copy Markdown
Member Author

@asomoza could you check this PR out?

@sayakpaul sayakpaul requested a review from DN6 March 26, 2026 06:00
@sayakpaul sayakpaul added the roadmap Add to current release roadmap label Mar 26, 2026
@sayakpaul sayakpaul marked this pull request as ready for review March 26, 2026 06:00
vkuzo added a commit to pytorch/ao that referenced this pull request Mar 27, 2026
Summary:

Adds memory pinning support for `MXTensor` and `NVFP4Tensor`, this is
important for supporting offloading in `diffusers`.

Fixes the torchao side of #4026

Still requires huggingface/diffusers#13276 to
land for things to work e2e

Test Plan:

e2e: makes
https://gist.github.com/vkuzo/9ca863e559eb8af18f6e6afb079bf74f work with
mxfp8 and nvfp4,
requires huggingface/diffusers#13276 (not yet
landed)

local: the new tests added in this PR pass
ghstack-source-id: ff0b1a2
ghstack-comment-id: 4142562671
Pull-Request: #4192
@sayakpaul
Copy link
Copy Markdown
Member Author

@claude could you review this PR?

@claude
Copy link
Copy Markdown

claude bot commented Mar 27, 2026

Claude finished @sayakpaul's task in 2m 29s —— View job


PR Review: fix group offloading when using torchao

  • Reading review rules and context
  • Reading changed file
  • Analyzing changes
  • Posting review

Overall, the approach is sound and well-documented. The two-strategy split (swap_tensors for onload, setattr for stream-offload) is clearly reasoned and matches the constraints. A few correctness issues worth addressing:

Bug: param.data.to(...) used for TorchAO in the no-stream offload path

_offload_to_memory (lines 321–329) uses param.data.to(...) and buffer.data.to(...) for TorchAO tensors in the no-stream path:

if _is_torchao_tensor(param):
    moved = param.data.to(self.offload_device, non_blocking=False)  # <-- bug
    _swap_torchao_tensor(param, moved)

This is the same problem the PR fixes elsewhere. Per the PR description, param.data on a _make_wrapper_subclass tensor returns a new wrapper each time without the internal attributes (.qdata, .scale). So moved would be an incomplete wrapper with internal attributes still on the original device, and swapping it in would be incorrect. The correct call is param.to(self.offload_device, non_blocking=False) directly — this goes through TorchAO's own to() dispatch which does move the internal attributes. Same issue applies to the buffer block immediately below. Fix this →

Concern: _init_cpu_param_dict also uses .data.cpu() for TorchAO tensors

Line 182 (and 185, 189, 192):

cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()

At init time the params are typically already on CPU, so param.data.cpu() is likely a no-op and the reference is correct. However:

  • param.data on a TorchAO tensor returns a new wrapper every time, so the stored value in cpu_param_dict may not be the canonical TorchAO object with its internal attributes accessible.
  • pin_memory() may not be implemented for TorchAO tensor subclasses. The PR description acknowledges this ("This assumes that the underlying quantization tensor class implements pinning properly"), but it could silently produce an incorrect cpu_param_dict entry rather than a clear failure. Worth adding a warning or comment here when _is_torchao_tensor is true and pinning is attempted. Fix this →

Omission: disk offload path not updated

_onload_from_disk (lines 249–261) and _offload_to_disk (line 294) still use bare .data = assignment and don't apply any TorchAO handling. This is presumably out of scope (safetensors can't directly serialize TorchAO subclasses), but it would be good to add an explicit check or error/warning if a TorchAO tensor reaches those code paths, rather than silently producing incorrect behaviour.

Minor: _swap_torchao_tensor is a one-liner wrapper

def _swap_torchao_tensor(param, source):
    torch.utils.swap_tensors(param, source)

The docstring has real value. But since the function adds no logic, it may be clearer to either inline the call (with a comment) or rename it to something more descriptive of the intent (e.g. _move_torchao_tensor_inplace). Minor preference — not a blocker.

Minor: attribute names fetched from source in _restore_torchao_tensor

for attr_name in _get_torchao_inner_tensor_names(source):
    setattr(param, attr_name, getattr(source, attr_name))

Getting names from source is fine since param and source are copies of the same tensor type. But if param previously held CUDA tensors for those same attributes (before the restore), those references are overwritten without explicit cleanup. Python's GC will handle this, but GPU memory may not be freed immediately — a potential peak-memory concern under tight GPU budgets. Not a correctness bug, just worth being aware of.


Summary: The two-strategy design is correct in principle and the documentation is excellent. The main actionable issue is the param.data.to(...) call for TorchAO in the no-stream _offload_to_memory path — it should be param.to(...) to actually move internal attributes. The _init_cpu_param_dict issue is worth a look too. Tests (noted as TODO by the author) should cover both stream and no-stream paths for TorchAO.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

roadmap Add to current release roadmap

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

3 participants